{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8a252366",
   "metadata": {},
   "source": [
    "# Early Stopping\n",
    "The following example shows how to use early stopping while training a model and how to create checkpoints."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f7d979c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../..')\n",
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '2'\n",
    "os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n",
    "import tensorflow as tf\n",
    "tf.get_logger().setLevel('ERROR')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "dafa4c11",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ampligraph\n",
    "# Benchmark datasets are under ampligraph.datasets module\n",
    "from ampligraph.datasets import load_fb15k_237\n",
    "# load fb15k-237 dataset\n",
    "dataset = load_fb15k_237()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "87ed1243",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/100\n",
      "29/29 [==============================] - 3s 94ms/step - loss: 6622.9019\n",
      "Epoch 2/100\n",
      "29/29 [==============================] - ETA: 0s - loss: 6493.6064"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-02-08 13:21:22.984591: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: RaggedFromRowLengths/RowPartitionFromRowLengths/assert_non_negative/assert_less_equal/Assert/AssertGuard/branch_executed/_8\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3/3 [==============================] - 1s 262ms/step\n",
      "29/29 [==============================] - 2s 81ms/step - loss: 6493.6064 - val_mrr: 0.0684 - val_mr: 3258.6080 - val_hits@1: 0.0000e+00 - val_hits@10: 0.1790 - val_hits@100: 0.3466\n",
      "Epoch 3/100\n",
      "29/29 [==============================] - 1s 34ms/step - loss: 6365.1802\n",
      "Epoch 4/100\n",
      "29/29 [==============================] - ETA: 0s - loss: 6234.0171"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-02-08 13:21:26.191442: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: RaggedFromRowLengths/RowPartitionFromRowLengths/assert_non_negative/assert_less_equal/Assert/AssertGuard/branch_executed/_8\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3/3 [==============================] - 1s 219ms/step\n",
      "29/29 [==============================] - 2s 69ms/step - loss: 6234.0171 - val_mrr: 0.0809 - val_mr: 1924.0540 - val_hits@1: 0.0000e+00 - val_hits@10: 0.2102 - val_hits@100: 0.4006\n",
      "Epoch 5/100\n",
      "29/29 [==============================] - 2s 59ms/step - loss: 6101.5239\n",
      "Epoch 6/100\n",
      "26/29 [=========================>....] - ETA: 0s - loss: 6003.9404"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-02-08 13:21:29.466724: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: RaggedFromRowLengths/RowPartitionFromRowLengths/assert_non_negative/assert_less_equal/Assert/AssertGuard/branch_executed/_8\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3/3 [==============================] - 1s 319ms/step\n",
      "29/29 [==============================] - 2s 55ms/step - loss: 5966.6636 - val_mrr: 0.0777 - val_mr: 1346.6420 - val_hits@1: 0.0000e+00 - val_hits@10: 0.2188 - val_hits@100: 0.4347\n",
      "Epoch 7/100\n",
      "29/29 [==============================] - 1s 27ms/step - loss: 5829.4419\n",
      "Epoch 8/100\n",
      "29/29 [==============================] - ETA: 0s - loss: 5690.5356"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-02-08 13:21:32.519336: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: RaggedFromRowLengths/RowPartitionFromRowLengths/assert_non_negative/assert_less_equal/Assert/AssertGuard/branch_executed/_8\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3/3 [==============================] - 1s 215ms/step\n",
      "29/29 [==============================] - 2s 72ms/step - loss: 5690.5356 - val_mrr: 0.0793 - val_mr: 1060.0114 - val_hits@1: 0.0000e+00 - val_hits@10: 0.2216 - val_hits@100: 0.4659\n",
      "Epoch 9/100\n",
      "29/29 [==============================] - 2s 56ms/step - loss: 5553.0103\n",
      "Epoch 10/100\n",
      "28/29 [===========================>..] - ETA: 0s - loss: 5434.2549"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-02-08 13:21:35.457800: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: RaggedFromRowLengths/RowPartitionFromRowLengths/assert_non_negative/assert_less_equal/Assert/AssertGuard/branch_executed/_8\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3/3 [==============================] - 1s 235ms/step\n",
      "29/29 [==============================] - 1s 46ms/step - loss: 5418.0444 - val_mrr: 0.0837 - val_mr: 901.1193 - val_hits@1: 0.0000e+00 - val_hits@10: 0.2216 - val_hits@100: 0.4830\n",
      "Epoch 11/100\n",
      "29/29 [==============================] - 1s 26ms/step - loss: 5286.5752\n",
      "Epoch 12/100\n",
      "      2/Unknown - 0s 159ms/step===>..] - ETA: 0s - loss: 5171.8193"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-02-08 13:21:39.521803: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: RaggedFromRowLengths/RowPartitionFromRowLengths/assert_non_negative/assert_less_equal/Assert/AssertGuard/branch_executed/_8\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3/3 [==============================] - 0s 147ms/step\n",
      "29/29 [==============================] - 3s 102ms/step - loss: 5158.7935 - val_mrr: 0.0832 - val_mr: 841.9688 - val_hits@1: 0.0000e+00 - val_hits@10: 0.2216 - val_hits@100: 0.4858\n",
      "Epoch 13/100\n",
      "29/29 [==============================] - 0s 12ms/step - loss: 5035.1396\n",
      "Epoch 14/100\n",
      "3/3 [==============================] - 0s 107ms/steposs: 4926.96\n",
      "29/29 [==============================] - 1s 24ms/step - loss: 4916.2104 - val_mrr: 0.0839 - val_mr: 815.3011 - val_hits@1: 0.0000e+00 - val_hits@10: 0.2330 - val_hits@100: 0.4972\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-02-08 13:21:40.702638: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: RaggedFromRowLengths/RowPartitionFromRowLengths/assert_non_negative/assert_less_equal/Assert/AssertGuard/branch_executed/_8\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 15/100\n",
      "29/29 [==============================] - 1s 26ms/step - loss: 4801.8057\n",
      "Epoch 16/100\n",
      "29/29 [==============================] - ETA: 0s - loss: 4691.9980"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-02-08 13:21:42.851130: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: RaggedFromRowLengths/RowPartitionFromRowLengths/assert_non_negative/assert_less_equal/Assert/AssertGuard/branch_executed/_8\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3/3 [==============================] - 1s 249ms/step\n",
      "29/29 [==============================] - 2s 61ms/step - loss: 4691.9980 - val_mrr: 0.0870 - val_mr: 777.9716 - val_hits@1: 0.0000e+00 - val_hits@10: 0.2301 - val_hits@100: 0.5114\n",
      "Epoch 17/100\n",
      "29/29 [==============================] - 1s 20ms/step - loss: 4587.6689\n",
      "Epoch 18/100\n",
      "27/29 [==========================>...] - ETA: 0s - loss: 4499.3130"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-02-08 13:21:45.612853: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: RaggedFromRowLengths/RowPartitionFromRowLengths/assert_non_negative/assert_less_equal/Assert/AssertGuard/branch_executed/_8\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3/3 [==============================] - 1s 218ms/step\n",
      "29/29 [==============================] - 2s 69ms/step - loss: 4488.2705 - val_mrr: 0.0833 - val_mr: 755.6392 - val_hits@1: 0.0000e+00 - val_hits@10: 0.2216 - val_hits@100: 0.5199\n",
      "Epoch 19/100\n",
      "29/29 [==============================] - 1s 29ms/step - loss: 4393.5967\n",
      "Epoch 20/100\n",
      "29/29 [==============================] - ETA: 0s - loss: 4303.0571"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-02-08 13:21:47.841319: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: RaggedFromRowLengths/RowPartitionFromRowLengths/assert_non_negative/assert_less_equal/Assert/AssertGuard/branch_executed/_8\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3/3 [==============================] - 1s 234ms/step\n",
      "29/29 [==============================] - 1s 48ms/step - loss: 4303.0571 - val_mrr: 0.0827 - val_mr: 739.2500 - val_hits@1: 0.0000e+00 - val_hits@10: 0.2216 - val_hits@100: 0.5227\n",
      "Epoch 21/100\n",
      "29/29 [==============================] - 0s 17ms/step - loss: 4217.0073\n",
      "Epoch 22/100\n",
      "      2/Unknown - 0s 152ms/step===>..] - ETA: 0s - loss: 4140.2354"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-02-08 13:21:50.470017: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: RaggedFromRowLengths/RowPartitionFromRowLengths/assert_non_negative/assert_less_equal/Assert/AssertGuard/branch_executed/_8\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3/3 [==============================] - 0s 165ms/step\n",
      "29/29 [==============================] - 2s 64ms/step - loss: 4134.3545 - val_mrr: 0.0839 - val_mr: 737.0227 - val_hits@1: 0.0000e+00 - val_hits@10: 0.2244 - val_hits@100: 0.5170\n",
      "Restoring model weights from the end of the best epoch: 16.\n",
      "Epoch 22: early stopping\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x28a6d6da0>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Import the KGE model\n",
    "from ampligraph.latent_features import ScoringBasedEmbeddingModel\n",
    "\n",
    "# create the model with transe scoring function\n",
    "model = ScoringBasedEmbeddingModel(eta=1, \n",
    "                                     k=10,\n",
    "                                     scoring_type='TransE')\n",
    "\n",
    "\n",
    "# compile the model with loss and optimizer\n",
    "model.compile(optimizer='adam', loss='multiclass_nll')\n",
    "\n",
    "# Use this for checkpoints at regular intervals\n",
    "#checkpoint = tf.keras.callbacks.ModelCheckpoint('./chkpt_transe', monitor='val_mrr', verbose=1, save_best_only=True, mode='min')\n",
    "\n",
    "# Use this for early stopping\n",
    "checkpoint = tf.keras.callbacks.EarlyStopping(monitor=\"val_mrr\",            # which metrics to monitor\n",
    "                                              patience=3,                   # If the monitored metric doesnt improve \n",
    "                                                                            # for these many checks the model early stops\n",
    "                                              verbose=1,                    # verbosity\n",
    "                                              mode=\"max\",                   # how to compare the monitored metrics. \n",
    "                                                                            # max - means higher is better\n",
    "                                              restore_best_weights=True)    # restore the weights with best value\n",
    "\n",
    "dataset = load_fb15k_237()\n",
    "\n",
    "model.fit(dataset['train'],\n",
    "          batch_size=10000,\n",
    "          epochs=100,\n",
    "          validation_freq=2,                            # Epochs to elapse before next evaluation\n",
    "          validation_batch_size=100,                    \n",
    "          validation_burn_in=1,                         # Epoch to start the validation process\n",
    "          validation_data = dataset['valid'][::100],\n",
    "          callbacks=[checkpoint])                       # Pass the callback to the fit function\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e5b4dc3b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "206/206 [==============================] - 616s 3s/step\n",
      "MR: 600.0142626480086\n",
      "MRR: 0.18888812773966743\n",
      "hits@1: 0.1267491926803014\n",
      "hits@10: 0.3130687934240141\n"
     ]
    }
   ],
   "source": [
    "# evaluate on the test set\n",
    "ranks = model.evaluate(dataset['test'], # test set\n",
    "                       batch_size=100, # evaluation batch size\n",
    "                       corrupt_side='s,o', \n",
    "                       use_filter={'train':dataset['train'], # Filter to be used for evaluation\n",
    "                                   'valid':dataset['valid'],\n",
    "                                   'test':dataset['test']}\n",
    "                       )\n",
    "\n",
    "# import the evaluation metrics\n",
    "from ampligraph.evaluation.metrics import mrr_score, hits_at_n_score, mr_score\n",
    "\n",
    "print('MR:', mr_score(ranks))\n",
    "print('MRR:', mrr_score(ranks))\n",
    "print('hits@1:', hits_at_n_score(ranks, 1))\n",
    "print('hits@10:', hits_at_n_score(ranks, 10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fee12a7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# if using ModelCheckpoint then we can restore the checkpoints using restore model\n",
    "# from ampligraph.utils import restore_model\n",
    "# model = restore_model('chkpt_transe')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.6 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "2e69f3670cdad0193847aaa0b77be56c05c951fcbdd384ff882dde0464f4de76"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
